# Copyright 2022 Twitter, Inc and Zhendong Wang.
# SPDX-License-Identifier: Apache-2.0

import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
from utils.logger import logger

from agents.diffusion import Diffusion
from agents.model import MLP
from agents.helpers import EMA

from agents.ql_diffusion import DiffusionQL


class QuantileGuidedDiffusion(DiffusionQL):
    def __init__(self,
                 state_dim,
                 action_dim,
                 max_action,
                 device,
                 discount,
                 tau,
                 using_guide=True,
                 prob_unconditional=0.1,
                 guide_weight=4.,
                 iql_quantile_point=0.7,
                 critic_policy='behavior',
                 critic_loss_type='expectile',
                 filter_policy='same',
                 filter_loss_type='expectile',
                 imitation_method='filter',
                 balance_condition_reweight=True,
                 balance_sample_reweight=False,
                 n_hard_updating_target_q=None,
                 weight_temperature=1.,
                 max_q_backup=False,
                 eta=1.0,
                 beta_schedule='linear',
                 n_timesteps=100,
                 ema_decay=0.995,
                 step_start_ema=1000,
                 update_ema_every=5,
                 lr=3e-4,
                 lr_decay=False,
                 lr_maxt=1000,
                 grad_norm=1.0,
                 condition_pos_embed=True,
                 training_epochs=None,
                 ):
        """Set up the value networks and, optionally, a condition-guided actor.

        The arguments ``state_dim`` through ``tau`` and ``max_q_backup``
        through ``grad_norm`` are forwarded unchanged to
        ``DiffusionQL.__init__``; see that class for their meaning.

        Args:
            using_guide: if True, build a conditional ``MLP``/``Diffusion``
                actor (with its own optimizer, EMA copy and optional cosine
                LR schedule) that replaces the attributes of the same name.
            prob_unconditional: forwarded to ``Diffusion``.
            guide_weight: forwarded to ``Diffusion``.
            iql_quantile_point: level used by the expectile loss and when
                picking one value out of a quantile network's outputs.
            critic_policy: ``'behavior'`` or ``'current'``; selects which
                actions the baseline network's regression target is
                evaluated on.
            critic_loss_type: loss for the baseline network: ``'original'``,
                ``'expectile'`` or one of ``'quantile-{101,51,21,11}'``.
            filter_policy: ``'same'`` keeps the filter network a copy of the
                baseline; any other value trains it separately on that
                policy's Q-values.
            filter_loss_type: loss for the filter network (same choices as
                ``critic_loss_type``).
            imitation_method: how actor conditions/weights are derived
                (``'filter'``, ``'exp_adv_weight'``, ``'quantile'``,
                ``'ge_quantile'``, ``'quantile_weight'`` or ``'expectile'``).
            balance_condition_reweight: re-weight actor samples so that the
                condition classes are balanced.
            balance_sample_reweight: re-weight samples by whether their
                Q-value exceeds the baseline.
            n_hard_updating_target_q: if not None, hard-copy the critic into
                its target every this many steps instead of soft updates.
            weight_temperature: temperature dividing the advantage for
                ``'exp_adv_weight'``.
            condition_pos_embed: forwarded to ``MLP``.
            training_epochs: optional mapping with keys ``'training_critic'``,
                ``'training_filter'`` and ``'training_actor'``, each an
                inclusive ``(first_step, last_step)`` window.
        """
        super(QuantileGuidedDiffusion, self).__init__(
            state_dim=state_dim,
            action_dim=action_dim,
            max_action=max_action,
            device=device,
            discount=discount,
            tau=tau,
            max_q_backup=max_q_backup,
            eta=eta,
            beta_schedule=beta_schedule,
            n_timesteps=n_timesteps,
            ema_decay=ema_decay,
            step_start_ema=step_start_ema,
            update_ema_every=update_ema_every,
            lr=lr,
            lr_decay=lr_decay,
            lr_maxt=lr_maxt,
            grad_norm=grad_norm,)
        self.using_guide = using_guide
        self.prob_unconditional = prob_unconditional
        self.guide_weight = guide_weight
        self.iql_quantile_point = iql_quantile_point

        self.critic_policy = critic_policy
        self.critic_loss_type = critic_loss_type
        self.filter_policy = filter_policy
        self.filter_loss_type = filter_loss_type

        self.balance_condition_reweight = balance_condition_reweight
        self.using_balance_loss = False
        self.condition_sets = (0, 1)
        self.balance_loss_weight = 1.
        print('-- using_balance_loss', self.using_balance_loss)
        print('-- condition_sets', self.condition_sets)
        print('-- balance_loss_weight', self.balance_loss_weight)

        self.imitation_method = imitation_method
        self.n_hard_updating_target_q = n_hard_updating_target_q
        self.weight_temperature = weight_temperature
        self.training_epochs = training_epochs
        self.balance_sample_reweight = balance_sample_reweight
        # Running estimate of the fraction of samples whose Q beats the baseline.
        self.balance_sample_weight = 1.

        quantile_loss_types = ('quantile-101', 'quantile-51', 'quantile-21', 'quantile-11')
        quantile_conditioned = self.imitation_method in ('quantile', 'ge_quantile')

        # Output width of each value network: N for 'quantile-N', else 1.
        if self.critic_loss_type in quantile_loss_types:
            self.n_critic_quantile_points = int(self.critic_loss_type.split('-')[-1])
        else:
            self.n_critic_quantile_points = 1
        # The filter width must always be defined because the filter network
        # below is built for every imitation method, not only the quantile ones.
        if quantile_conditioned or self.filter_loss_type in quantile_loss_types:
            self.n_filter_quantile_points = int(self.filter_loss_type.split('-')[-1])
        else:
            self.n_filter_quantile_points = 1

        if quantile_conditioned:
            # Running frequency of each quantile bin, used for balanced re-weighting.
            self.condition_running_average = torch.ones([self.n_filter_quantile_points], device=self.device)
            self.condition_running_ratio = 0.999

        if self.using_guide:
            self.model = MLP(state_dim=state_dim, action_dim=action_dim, device=device,
                             condition_pos_embed=condition_pos_embed)
            self.actor = Diffusion(state_dim=state_dim, action_dim=action_dim, model=self.model, max_action=max_action,
                                   using_guide=using_guide, prob_unconditional=prob_unconditional,
                                   guide_weight=guide_weight,
                                   using_balance_loss=self.using_balance_loss, condition_sets=self.condition_sets,
                                   balance_loss_weight=self.balance_loss_weight,
                                   beta_schedule=beta_schedule, n_timesteps=n_timesteps,).to(device)
            self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=lr)
            self.ema_model = copy.deepcopy(self.actor)
            if lr_decay:
                self.actor_lr_scheduler = CosineAnnealingLR(self.actor_optimizer, T_max=lr_maxt, eta_min=0.)

        hidden_dim = 256

        def build_value_net(n_outputs):
            # State -> value MLP shared by the baseline and filter networks.
            return nn.Sequential(nn.Linear(state_dim, hidden_dim),
                                 nn.Mish(),
                                 nn.Linear(hidden_dim, hidden_dim),
                                 nn.Mish(),
                                 nn.Linear(hidden_dim, hidden_dim),
                                 nn.Mish(),
                                 nn.Linear(hidden_dim, n_outputs)
                                 ).to(self.device)

        # Both value networks use a fixed learning rate, independent of `lr`.
        self.baseline = build_value_net(self.n_critic_quantile_points)
        self.baseline_optimizer = torch.optim.Adam(self.baseline.parameters(), lr=3e-4)
        self.filter = build_value_net(self.n_filter_quantile_points)
        self.filter_optimizer = torch.optim.Adam(self.filter.parameters(), lr=3e-4)
        if self.filter_policy == 'same':
            # Start the filter as an exact copy of the baseline; this requires
            # both networks to have the same output width.
            for param, target_param in zip(self.baseline.parameters(), self.filter.parameters()):
                target_param.data.copy_(param.data)

    def _cal_baseline_value(self, state, baseline, critic_loss_type):
        """Evaluate a state-value network without tracking gradients.

        Args:
            state: batch of states passed to ``baseline``.
            baseline: callable network mapping states to value estimates.
            critic_loss_type: loss the network is trained with; the
                ``'quantile-N'`` variants mark a multi-output quantile network.

        Returns:
            Tuple ``(bs_value, bs_values)``. For a quantile network,
            ``bs_values`` is the raw network output and ``bs_value`` its
            ``self.iql_quantile_point`` quantile along the last dimension
            (kept as a size-1 dimension). Otherwise ``bs_values`` is ``None``
            and ``bs_value`` is the raw network output.
        """
        quantile_loss_types = ('quantile-101', 'quantile-51', 'quantile-21', 'quantile-11')
        with torch.no_grad():
            raw_output = baseline(state)
            if critic_loss_type not in quantile_loss_types:
                # Single-output network: its output already is the value.
                return raw_output, None
            picked = torch.quantile(raw_output, self.iql_quantile_point, dim=-1, keepdim=True)
        return picked, raw_output

    def _cal_target_q(self, state, action, critic_policy, using_target_critic=True):
        """Return the detached clipped double-Q value used as a regression target.

        Args:
            state: batch of states.
            action: batch of dataset actions; only used for ``'behavior'``.
            critic_policy: ``'behavior'`` evaluates the given ``action``;
                ``'current'`` evaluates an action sampled from the EMA actor,
                conditioned according to ``self.imitation_method``.
            using_target_critic: evaluate with ``self.critic_target`` when
                True, otherwise with ``self.critic``.

        Returns:
            Element-wise minimum of the two critic heads, without gradient.

        Raises:
            NotImplementedError: for an unknown ``critic_policy`` or, with a
                guided actor, an unknown ``self.imitation_method``.
        """
        # The result is a fixed target, so neither the sampler nor the critic
        # needs an autograd graph here.
        with torch.no_grad():
            if critic_policy == 'behavior':
                eval_action = action
            elif critic_policy == 'current':
                if not self.using_guide:
                    eval_action = self.ema_model(state)
                elif self.imitation_method in ('filter', 'exp_adv_weight', 'quantile_weight'):
                    # Binary conditions: 1 selects the "better than filter" mode,
                    # matching what sample_action uses for these methods.
                    condition = torch.ones([state.shape[0]], device=self.device)
                    eval_action = self.ema_model(state, condition=condition)
                elif self.imitation_method in ('quantile', 'ge_quantile'):
                    condition = torch.ones([state.shape[0]], device=self.device) * self.iql_quantile_point
                    eval_action = self.ema_model(state, condition=condition)
                elif self.imitation_method == 'expectile':
                    bs_value, _ = self._cal_baseline_value(state, self.filter, self.filter_loss_type)
                    eval_action = self.ema_model(state, condition=bs_value.squeeze(-1))
                else:
                    raise NotImplementedError
            else:
                raise NotImplementedError

            critic = self.critic_target if using_target_critic else self.critic
            q1, q2 = critic(state, eval_action)
            q = torch.min(q1, q2).detach()
        return q

    def _cal_bs_loss(self, state, target_q, baseline, critic_loss_type, n_quantile_points=1):
        """Compute the un-reduced regression loss of a value network.

        Args:
            state: batch of states passed to ``baseline``.
            target_q: regression target for the network output.
            baseline: callable network being trained.
            critic_loss_type: ``'original'`` (plain squared error),
                ``'expectile'`` (squared error weighted asymmetrically by
                ``self.iql_quantile_point``) or one of
                ``'quantile-{101,51,21,11}'`` (absolute error weighted per
                output by its quantile level).
            n_quantile_points: number of quantile levels; only read by the
                quantile losses.

        Returns:
            Element-wise loss tensor (no reduction applied).

        Raises:
            NotImplementedError: for any other ``critic_loss_type``.
        """
        prediction = baseline(state)

        if critic_loss_type == 'original':
            return F.mse_loss(prediction, target_q, reduction='none')

        if critic_loss_type == 'expectile':
            squared_error = F.mse_loss(prediction, target_q, reduction='none')
            # |level - 1[target < prediction]|: the asymmetric weight.
            asymmetry = torch.abs(self.iql_quantile_point - (target_q < prediction).float()).detach()
            return squared_error * asymmetry

        if critic_loss_type in ('quantile-101', 'quantile-51', 'quantile-21', 'quantile-11'):
            absolute_error = F.l1_loss(prediction, target_q, reduction='none')
            # Evenly spaced levels, with the end points nudged off 0 and 1.
            inner_levels = [1. / (n_quantile_points - 1) * i for i in range(1, n_quantile_points - 1)]
            levels = torch.Tensor([[0.0001] + inner_levels + [0.9999]]).to(self.device)
            asymmetry = torch.abs(levels - (target_q < prediction).float()).detach()
            return absolute_error * asymmetry

        raise NotImplementedError

    def train(self, replay_buffer, iterations, batch_size=100, log_writer=None):
        """Run ``iterations`` optimisation steps and return the logged metrics.

        Each step samples one batch and, subject to ``self.training_epochs``,
        runs up to three phases:

        1. critic phase: fit the twin Q critic to a clamped one-step target
           built from the baseline value of the next state, update the target
           critic (soft update with ``self.tau``, or a hard copy every
           ``self.n_hard_updating_target_q`` steps), then fit the baseline;
        2. filter phase: copy the baseline into the filter when
           ``self.filter_policy == 'same'``, otherwise fit the filter with
           its own loss;
        3. actor phase: derive a per-sample ``condition`` and ``weights``
           from ``self.imitation_method`` and minimise the actor loss.

        Args:
            replay_buffer: object providing ``sample(batch_size)`` that
                returns ``(state, action, next_state, reward, not_done)`` and
                the attributes ``reward_min`` and ``reward_max``.
            iterations: number of steps to run.
            batch_size: batch size requested from the replay buffer.
            log_writer: not used by this method.

        Returns:
            dict mapping each metric name to a list with one float per step.

        Raises:
            NotImplementedError: for an unknown ``self.imitation_method``.
        """
        quantile_loss_types = ('quantile-101', 'quantile-51', 'quantile-21', 'quantile-11')

        def phase_active(name):
            # A phase runs on every step unless an inclusive step window is configured.
            if self.training_epochs is None:
                return True
            window = self.training_epochs[name]
            return window[0] <= self.step <= window[1]

        metric = {'target_q': [],
                  'bc_loss': [], 'ql_loss': [], 'actor_loss': [],
                  'q_critic_loss': [], 'v_critic_loss': [], 'filter_critic_loss': [], 'critic_loss': []}
        for _ in range(iterations):
            # Zero placeholders so logging still works for phases skipped this step.
            target_q = torch.zeros([1])
            bc_loss = torch.zeros([1])
            ql_loss = torch.zeros([1])
            actor_loss = torch.zeros([1])
            q_critic_loss = torch.zeros([1])
            v_critic_loss = torch.zeros([1])
            filter_critic_loss = torch.zeros([1])

            # Sample replay buffer / batch
            state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)
            n_samples = state.shape[0]

            """ Q Training """
            tmp_current_q = None
            sample_weight = None
            if phase_active('training_critic'):
                current_q1, current_q2 = self.critic(state, action)
                with torch.no_grad():
                    next_value, _ = self._cal_baseline_value(next_state, self.baseline, self.critic_loss_type)
                    target_q = (reward + not_done * self.discount * next_value)
                    # Clamp to the range of discounted returns reachable with the
                    # buffer's reward bounds (always including 0).
                    target_q = torch.clamp(target_q,
                                           min=min(0, replay_buffer.reward_min / (1 - self.discount)),
                                           max=max(0, replay_buffer.reward_max / (1 - self.discount)))
                q_critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)
                self.critic_optimizer.zero_grad()
                q_critic_loss.backward()
                if self.grad_norm > 0:
                    nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm=self.grad_norm, norm_type=2)
                self.critic_optimizer.step()

                if self.n_hard_updating_target_q is None:
                    # Polyak averaging of the target critic.
                    for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                        target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
                elif (self.step + 1) % self.n_hard_updating_target_q == 0:
                    for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                        target_param.data.copy_(param.data)

                tmp_current_q = self._cal_target_q(state, action, self.critic_policy, True)
                v_critic_loss = self._cal_bs_loss(state, tmp_current_q, self.baseline, self.critic_loss_type,
                                                  self.n_critic_quantile_points).mean()
                self.baseline_optimizer.zero_grad()
                v_critic_loss.backward()
                if self.grad_norm > 0:
                    nn.utils.clip_grad_norm_(self.baseline.parameters(), max_norm=self.grad_norm, norm_type=2)
                self.baseline_optimizer.step()

            if phase_active('training_filter'):
                if self.filter_policy == 'same':
                    # The filter simply mirrors the baseline.
                    for param, target_param in zip(self.baseline.parameters(), self.filter.parameters()):
                        target_param.data.copy_(param.data)
                    filter_critic_loss = v_critic_loss.detach()
                else:
                    # Reuse the baseline's target only if it was built with the same policy.
                    if tmp_current_q is None or self.filter_policy != self.critic_policy:
                        tmp_current_q = self._cal_target_q(state, action, self.filter_policy, True)
                    with torch.no_grad():
                        sample_weight = 1.
                        if self.balance_sample_reweight:
                            tmp_baseline, _ = self._cal_baseline_value(state, self.baseline, self.critic_loss_type)
                            tmp_better_weight = (tmp_current_q > tmp_baseline).float()
                            # Exponential moving average of the "better than baseline" rate.
                            self.balance_sample_weight = \
                                max(1e-3, self.balance_sample_weight * 0.99 + tmp_better_weight.mean() * 0.01)
                            sample_weight = tmp_better_weight / self.balance_sample_weight + (1 - tmp_better_weight)
                    filter_critic_loss = self._cal_bs_loss(state, tmp_current_q, self.filter, self.filter_loss_type,
                                                           self.n_filter_quantile_points)
                    filter_critic_loss = (filter_critic_loss * sample_weight).mean()
                    self.filter_optimizer.zero_grad()
                    filter_critic_loss.backward()
                    if self.grad_norm > 0:
                        nn.utils.clip_grad_norm_(self.filter.parameters(), max_norm=self.grad_norm, norm_type=2)
                    self.filter_optimizer.step()

            """ Policy Training """
            if phase_active('training_actor'):
                current_q = self._cal_target_q(state, action, 'behavior', True)
                if self.iql_quantile_point > 1.:
                    # A level above 1 is not a valid quantile: evaluate the filter
                    # at 0.95 and always restore the configured value afterwards.
                    configured_point = self.iql_quantile_point
                    self.iql_quantile_point = 0.95
                    try:
                        filter_value, filter_values = \
                            self._cal_baseline_value(state, self.filter, self.filter_loss_type)
                    finally:
                        self.iql_quantile_point = configured_point
                else:
                    filter_value, filter_values = self._cal_baseline_value(state, self.filter, self.filter_loss_type)
                weights = 1.
                condition = None
                # Unconditional, unweighted actor loss; logged only.
                bc_loss = self.actor.loss(action, state, condition=None).detach()
                with torch.no_grad():
                    adv = current_q - filter_value
                    if self.imitation_method == 'exp_adv_weight':
                        if not self.using_guide:
                            weights = torch.clamp_max(torch.exp(adv / self.weight_temperature), 100.).detach()
                        else:
                            condition = (adv > 0).float().squeeze(-1).detach()
                            # Flip all labels for half of the batches.
                            if np.random.uniform() > 0.5:
                                condition = 1 - condition
                            # Squeeze adv so it pairs element-wise with the 1-D
                            # condition instead of broadcasting to a square matrix.
                            signed_adv = adv.squeeze(-1) / self.weight_temperature * (condition * 2 - 1)
                            weights = torch.exp(signed_adv).unsqueeze(-1)
                            weights = torch.clamp_max(weights, 100.).detach()
                    elif self.imitation_method == 'filter':
                        if not self.using_guide:
                            weights = (adv > 0).float().detach()
                            if np.random.rand() < 0.001:
                                print('weights mean', weights.mean().item())
                        else:
                            condition = (adv > 0).float().squeeze(-1).detach()
                            if np.random.rand() < 0.001:
                                print('condition mean', condition.mean().item())
                            if self.balance_condition_reweight:
                                # Weight each class by the inverse of its nominal
                                # frequency so both conditions contribute equally.
                                weights = ((adv > 0).float() / (1 - self.iql_quantile_point)
                                           + (adv <= 0).float() / self.iql_quantile_point).detach()
                    elif self.imitation_method == 'quantile':
                        assert self.using_guide and self.filter_loss_type in quantile_loss_types
                        # Index of the highest quantile the sample's Q-value exceeds.
                        condition = (current_q > filter_values).float().sum(dim=-1)
                        condition = torch.clamp_min(condition - 1., min=0.).detach()
                    elif self.imitation_method == 'ge_quantile':
                        assert self.using_guide and self.filter_loss_type in quantile_loss_types
                        # Uniformly pick one of the quantile bins the sample exceeds.
                        tmp_quantiles = (current_q > filter_values).float().sum(dim=-1).detach()
                        condition = torch.floor(torch.rand_like(tmp_quantiles) * tmp_quantiles)
                    elif self.imitation_method == 'quantile_weight':
                        assert self.using_guide and self.filter_loss_type in quantile_loss_types
                        if np.random.rand() < 0.5:
                            condition = torch.ones([n_samples]).to(self.device)
                            weights = (current_q > filter_values).float().sum(dim=-1, keepdim=True).detach()
                        else:
                            condition = torch.zeros([n_samples]).to(self.device)
                            weights = (current_q < filter_values).float().sum(dim=-1, keepdim=True).detach()
                    elif self.imitation_method == 'expectile':
                        assert self.using_guide
                        condition = filter_value.squeeze(-1)
                    else:
                        raise NotImplementedError
                    if self.imitation_method == 'quantile' or self.imitation_method == 'ge_quantile':
                        if self.balance_condition_reweight:
                            # Track how often each quantile bin is used and weight
                            # samples by the inverse running frequency of their bin.
                            bin_ids = torch.arange(self.n_filter_quantile_points,
                                                   device=self.device, dtype=condition.dtype)
                            condition_number = (bin_ids.unsqueeze(-1) == condition.unsqueeze(0)).sum(-1) / n_samples
                            self.condition_running_average = self.condition_running_average * self.condition_running_ratio \
                                                             + condition_number * (1 - self.condition_running_ratio)
                            weights = torch.clamp(1. / (1e-6 + self.condition_running_average), min=0.01, max=100.)
                            weights = weights[condition.long()].unsqueeze(-1)
                            weights = weights / weights.mean()
                        elif self.balance_sample_reweight:
                            tmp_baseline, _ = self._cal_baseline_value(state, self.baseline, self.critic_loss_type)
                            tmp_better_weight = (current_q > tmp_baseline).float()
                            sample_weight = tmp_better_weight / self.balance_sample_weight + (1 - tmp_better_weight)
                            weights = sample_weight
                        # Map the bin index to [0, 1].
                        condition = condition / (self.n_filter_quantile_points - 1)
                ql_loss = self.actor.loss(action, state, condition=condition, weights=weights)
                actor_loss = ql_loss

                self.actor_optimizer.zero_grad()
                actor_loss.backward()
                if self.grad_norm > 0:
                    nn.utils.clip_grad_norm_(self.actor.parameters(), max_norm=self.grad_norm, norm_type=2)
                self.actor_optimizer.step()

                """ Step Target network """
                if self.step % self.update_ema_every == 0:
                    self.step_ema()

            self.step += 1

            """ Log """
            metric['target_q'].append(target_q.mean().item())
            metric['bc_loss'].append(bc_loss.item())
            metric['ql_loss'].append(ql_loss.item())
            metric['actor_loss'].append(actor_loss.item())
            metric['q_critic_loss'].append(q_critic_loss.item())
            metric['v_critic_loss'].append(v_critic_loss.item())
            metric['filter_critic_loss'].append(filter_critic_loss.item())
            metric['critic_loss'].append((q_critic_loss + v_critic_loss).item())

        # Schedulers advance once per call to train(), not once per step.
        if self.lr_decay:
            self.actor_lr_scheduler.step()
            self.critic_lr_scheduler.step()

        return metric

    def sample_action(self, state):
        """Sample one action per input state, resampled by Q-value.

        For every state 50 candidate actions are drawn from the actor; one of
        them is then picked at random with probability given by the softmax
        of the candidates' ``self.critic.q_min`` values.

        Args:
            state: array-like reshaped to ``(-1, self.state_dim)``.

        Returns:
            numpy array with one action of size ``self.action_dim`` per state.

        Raises:
            NotImplementedError: with a guided actor, for an unknown
                ``self.imitation_method``.
        """
        state = torch.FloatTensor(state.reshape(-1, self.state_dim)).to(self.device)
        n_batch = state.shape[0]
        state_rpt = torch.repeat_interleave(state, repeats=50, dim=0)
        with torch.no_grad():
            if not self.using_guide:
                action = self.actor.sample(state_rpt)
            else:
                if self.imitation_method in ('filter', 'exp_adv_weight', 'quantile_weight'):
                    # Binary conditions: ask for the "better than filter" mode.
                    condition = torch.ones([state_rpt.shape[0]], device=self.device)
                elif self.imitation_method in ('quantile', 'ge_quantile'):
                    condition = torch.ones([state_rpt.shape[0]], device=self.device) * self.iql_quantile_point
                elif self.imitation_method == 'expectile':
                    # Condition on the filter network, the same source the actor
                    # is conditioned on during training.
                    bs_value, _ = self._cal_baseline_value(state_rpt, self.filter, self.filter_loss_type)
                    condition = bs_value.squeeze(-1)
                else:
                    raise NotImplementedError
                action = self.actor.sample(state_rpt, condition=condition)
            q_value = self.critic.q_min(state_rpt, action).flatten()
            # One softmax-weighted draw among each state's candidates.
            idx = torch.multinomial(F.softmax(q_value.reshape(n_batch, -1), dim=-1), 1)
            action = torch.gather(action.reshape(n_batch, -1, self.action_dim),
                                  dim=1,
                                  index=idx.unsqueeze(-1).expand([n_batch, 1, self.action_dim]))
            action = action.squeeze(1)
            action = action.cpu().numpy()
        return action

    def save_model(self, dir, id=''):
        """Write the state dicts of all networks and optimizers to ``dir``.

        Each component is stored as ``{dir}/{name}_{id}.pth``; ``dir`` is
        created first if it does not exist.

        Args:
            dir: destination directory.
            id: suffix inserted into every file name.
        """
        import os
        if not os.path.exists(dir):
            os.makedirs(dir)

        # Attribute name doubles as the file-name stem; order matches write order.
        component_names = (
            'critic', 'critic_optimizer', 'critic_target',
            'baseline', 'baseline_optimizer',
            'filter', 'filter_optimizer',
            'actor', 'actor_optimizer', 'ema_model',
        )
        for component_name in component_names:
            component = getattr(self, component_name)
            torch.save(component.state_dict(), f'{dir}/{component_name}_{id}.pth')

    def load_model(self, dir, id=''):
        """Restore all networks and optimizers from files in ``dir``.

        Reads ``{dir}/{name}_{id}.pth`` for every component, mapping tensors
        onto ``self.device``.

        Args:
            dir: directory holding the checkpoint files.
            id: suffix that was used in the file names when saving.
        """
        # Attribute name doubles as the file-name stem; order matches read order.
        component_names = (
            'critic', 'critic_optimizer', 'critic_target',
            'baseline', 'baseline_optimizer',
            'filter', 'filter_optimizer',
            'actor', 'actor_optimizer', 'ema_model',
        )
        for component_name in component_names:
            component = getattr(self, component_name)
            checkpoint = torch.load(f'{dir}/{component_name}_{id}.pth', map_location=self.device)
            component.load_state_dict(checkpoint)

